import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import inception_v3
from torchvision import datasets
# MNIST training split; download=True lets torchvision fetch it into ./data
# when it is not already there.
mnist_dataset_train = datasets.MNIST(
    download=True,
    train=True,
    root='./data',
)

def preprocess_images(images, transform=None):
    """Preprocess a batch of images for InceptionV3.

    Args:
        images: iterable of 2-D arrays. Values are expected to be intensities
            in [0, 1]; they are scaled by 255 and converted to 8-bit pixels.
        transform: optional callable applied to each PIL image. Defaults to
            ``get_transform()``, which is built once per call.

    Returns:
        torch.Tensor stacking the transformed images along a new leading dim.

    Raises:
        RuntimeError: from ``torch.stack`` when ``images`` is empty.
    """
    if transform is None:
        transform = get_transform()
    processed_images = []
    for img_array in images:
        # Round rather than truncate. A float32 value such as k/255 can come
        # out as k - tiny after *255 and would otherwise drop to k - 1. Clip
        # first so out-of-range inputs (e.g. negative generator outputs)
        # saturate instead of wrapping around in the uint8 cast.
        scaled = np.asarray(img_array, dtype=np.float64) * 255
        pixels = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)  # 2-D uint8 array -> greyscale PIL image
        processed_images.append(transform(img))
    return torch.stack(processed_images)

def get_transform():
    """Build the per-image preprocessing pipeline used before InceptionV3.

    The pipeline resizes to 299x299 and replicates the single grey channel
    into three identical channels. It then converts to a float tensor and
    normalises each channel with the mean/std values below. These are the
    ImageNet statistics torchvision documents for its pretrained models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((299, 299)),
        transforms.Grayscale(num_output_channels=3),  # 1 grey channel -> 3 equal channels
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)


def get_inception_model():
    """Load a pretrained InceptionV3 configured as an FID feature extractor.

    The final fully connected classifier is replaced with an identity, so the
    forward pass returns the 2048-d average-pooled activations that FID is
    defined on instead of 1000 ImageNet class logits. The model is put in
    evaluation mode; in eval mode torchvision's Inception3 returns a plain
    tensor rather than an (output, aux) pair.
    """
    # transform_input=True: torchvision's pretrained Inception weights expect
    # ImageNet mean/std-normalised inputs to be re-mapped internally to the
    # scale they were trained with. torchvision enables this by default for
    # pretrained weights; forcing False feeds the network mis-scaled inputs.
    model = inception_v3(pretrained=True, transform_input=True)
    model.fc = torch.nn.Identity()  # expose pooled features instead of logits
    model.eval()  # use BatchNorm running stats; disable dropout
    return model

def extract_features(images, model, batch_size=None):
    """Extract features from a tensor of images.

    Args:
        images: tensor whose first dimension indexes samples.
        model: callable (typically an ``nn.Module``) mapping a batch tensor to
            an output tensor.
        batch_size: optional maximum number of samples per forward pass.
            ``None`` (the default) runs all of ``images`` in a single pass.

    Returns:
        numpy array of model outputs for every sample, in input order.

    Raises:
        ValueError: if ``batch_size`` is given and is not positive.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Put inputs on the same device as the model's parameters. A model moved
    # to CUDA would otherwise reject CPU input tensors. Parameter-free models
    # and plain callables are left untouched.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    device = None if first_param is None else first_param.device

    total = len(images)
    step = batch_size if batch_size is not None else max(total, 1)
    outputs = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        # max(total, 1) keeps exactly one (possibly empty) pass for empty input.
        for start in range(0, max(total, 1), step):
            batch = images[start:start + step]
            if device is not None:
                batch = batch.to(device)
            outputs.append(model(batch).cpu().numpy())
    return outputs[0] if len(outputs) == 1 else np.concatenate(outputs, axis=0)


from scipy.linalg import sqrtm

def calculate_fid_score(features1, features2, eps=1e-6):
    """Calculate the Fréchet Inception Distance (FID) score.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)), where mu and S
    are the mean and covariance of each feature set.

    Args:
        features1, features2: arrays of shape (n_samples, n_features); rows are
            samples (``rowvar=False``). A 1-D array is treated as samples of a
            single feature.
        eps: diagonal offset added to both covariances if the matrix square
            root is not finite (e.g. singular covariance estimates).

    Returns:
        The FID as a scalar.
    """
    import warnings

    mu1, mu2 = np.mean(features1, axis=0), np.mean(features2, axis=0)
    # np.cov squeezes its result to 0-d for a single feature; keep it 2-D so
    # the matrix products and sqrtm below work in every case.
    sigma1 = np.atleast_2d(np.cov(features1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(features2, rowvar=False))
    mu_diff = np.atleast_1d(mu1 - mu2)

    cov_mean = sqrtm(sigma1 @ sigma2)
    if not np.all(np.isfinite(cov_mean)):
        # Product is (near-)singular: regularise both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        cov_mean = sqrtm((sigma1 + offset) @ (sigma2 + offset))
    if np.iscomplexobj(cov_mean):
        # Numerical error can give small imaginary parts; the true root is
        # real, so keep the real part and warn if the error looks large.
        if not np.allclose(np.diagonal(cov_mean).imag, 0, atol=1e-3):
            max_imag = np.max(np.abs(cov_mean.imag))
            warnings.warn(
                f"sqrtm produced a significant imaginary component ({max_imag:.3g}); "
                "FID may be inaccurate",
                RuntimeWarning,
            )
        cov_mean = cov_mean.real

    fid = mu_diff @ mu_diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_mean)
    return fid


# Real images: MNIST training set as float intensities in [0, 1].
# `.data` is the non-deprecated name for the same uint8 tensor as `.train_data`.
train_data_real = (mnist_dataset_train.data / 255.).numpy()
# Each checkpoint's dict values are unpacked as (images, labels), which relies
# on the dict's insertion order. The labels are not used below.
Train_100_image, Train_100_y = torch.load('GAN_100_train.pth').values()
Train_300_image, Train_300_y = torch.load('GAN_300_train.pth').values()
Train_500_image, Train_500_y = torch.load('GAN_500_train.pth').values()

# Run Inception on a GPU when available; each batch is moved to the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inception_model = get_inception_model().to(device)

NUM_IMAGES = 60000  # compare the first 60k images of every set (MNIST train size)
CHUNK_SIZE = 1000   # images per preprocessing/forward chunk; bounds peak memory

Set_real, Set_syn_100, Set_syn_300, Set_syn_500 = [], [], [], []
# (images, feature list) pairs. Generated sets are flattened to (N, 28, 28)
# once, outside the loop; -1 avoids hard-coding how many images were saved.
jobs = (
    (train_data_real, Set_real),
    (Train_100_image.detach().cpu().reshape(-1, 28, 28).numpy(), Set_syn_100),
    (Train_300_image.detach().cpu().reshape(-1, 28, 28).numpy(), Set_syn_300),
    (Train_500_image.detach().cpu().reshape(-1, 28, 28).numpy(), Set_syn_500),
)
for images, _ in jobs:
    if len(images) < NUM_IMAGES:
        raise ValueError(f'expected at least {NUM_IMAGES} images, got {len(images)}')

# Preprocess and extract features chunk by chunk, one set at a time, so only
# a single 1000-image 299x299 batch is resident at once.
for i in range(NUM_IMAGES // CHUNK_SIZE):
    Num_1, Num_2 = i * CHUNK_SIZE, (i + 1) * CHUNK_SIZE
    for images, features in jobs:
        batch = preprocess_images(images[Num_1:Num_2]).to(device)
        features.append(extract_features(batch, inception_model))
    print(i)

Set_real_1, Set_syn_1 = np.concatenate(Set_real), np.concatenate(Set_syn_100)
Set_syn_3, Set_syn_5 = np.concatenate(Set_syn_300), np.concatenate(Set_syn_500)
np.save('Set_real', Set_real_1)
np.save('Set_syn_100', Set_syn_1)
np.save('Set_syn_300', Set_syn_3)
np.save('Set_syn_500', Set_syn_5)
# Compute FID of each generated set against the real features
fid_score_1 = calculate_fid_score(Set_real_1, Set_syn_1)
fid_score_3 = calculate_fid_score(Set_real_1, Set_syn_3)
fid_score_5 = calculate_fid_score(Set_real_1, Set_syn_5)
print(fid_score_1, fid_score_3, fid_score_5)
